# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for DPR."""


import collections
import logging
from typing import List, Optional, Union

from .file_utils import add_end_docstrings, add_start_docstrings
from .tokenization_bert import BertTokenizer, BertTokenizerFast
from .tokenization_utils_base import BatchEncoding, TensorType


# Module-level logger (not referenced elsewhere in the visible part of this file).
logger = logging.getLogger(__name__)

# Tokenizer argument name -> file name of the vocabulary resource.
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

# Checkpoint name -> vocabulary URL, one map per DPR component (context encoder,
# question encoder, reader). All three checkpoints listed here point at the same
# bert-base-uncased vocabulary file.
CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
    }
}
QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "facebook/dpr-question_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
    }
}
READER_PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "facebook/dpr-reader-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
    }
}

# Checkpoint name -> size value (512 for every checkpoint listed). These dicts are
# assigned to `max_model_input_sizes` on the tokenizer classes below.
CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "facebook/dpr-ctx_encoder-single-nq-base": 512,
}
QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "facebook/dpr-question_encoder-single-nq-base": 512,
}
READER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "facebook/dpr-reader-single-nq-base": 512,
}


# Checkpoint name -> keyword arguments, assigned to `pretrained_init_configuration`
# on the tokenizer classes below. Every checkpoint listed sets `do_lower_case` to True.
CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
    "facebook/dpr-ctx_encoder-single-nq-base": {"do_lower_case": True},
}
QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
    "facebook/dpr-question_encoder-single-nq-base": {"do_lower_case": True},
}
READER_PRETRAINED_INIT_CONFIGURATION = {
    "facebook/dpr-reader-single-nq-base": {"do_lower_case": True},
}


class DPRContextEncoderTokenizer(BertTokenizer):
    r"""
    Constructs a DPRContextEncoderTokenizer.

    :class:`~transformers.DPRContextEncoderTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end
    tokenization: punctuation splitting + wordpiece. It adds no methods of its own; it only overrides the class-level
    tables that describe the DPR context-encoder checkpoints.

    Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
    parameters.
    """

    # Class-level tables for the context-encoder checkpoints.
    # NOTE(review): presumably read by the base tokenizer when loading a pretrained
    # checkpoint by name — confirm against the base class.
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    pretrained_init_configuration = CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION


class DPRContextEncoderTokenizerFast(BertTokenizerFast):
    r"""
    Constructs a "Fast" DPRContextEncoderTokenizer (backed by HuggingFace's `tokenizers` library).

    :class:`~transformers.DPRContextEncoderTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs end-to-end
    tokenization: punctuation splitting + wordpiece. It adds no methods of its own; it only overrides the class-level
    tables that describe the DPR context-encoder checkpoints.

    Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning
    parameters.
    """

    # Same checkpoint tables as the non-fast context-encoder tokenizer.
    # NOTE(review): presumably read by the base tokenizer when loading a pretrained
    # checkpoint by name — confirm against the base class.
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    pretrained_init_configuration = CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION


class DPRQuestionEncoderTokenizer(BertTokenizer):
    r"""
    Constructs a DPRQuestionEncoderTokenizer.

    :class:`~transformers.DPRQuestionEncoderTokenizer` is identical to :class:`~transformers.BertTokenizer` and runs end-to-end
    tokenization: punctuation splitting + wordpiece. It adds no methods of its own; it only overrides the class-level
    tables that describe the DPR question-encoder checkpoints.

    Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
    parameters.
    """

    # Class-level tables for the question-encoder checkpoints.
    # NOTE(review): presumably read by the base tokenizer when loading a pretrained
    # checkpoint by name — confirm against the base class.
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    pretrained_init_configuration = QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION


class DPRQuestionEncoderTokenizerFast(BertTokenizerFast):
    r"""
    Constructs a "Fast" DPRQuestionEncoderTokenizer (backed by HuggingFace's `tokenizers` library).

    :class:`~transformers.DPRQuestionEncoderTokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs end-to-end
    tokenization: punctuation splitting + wordpiece. It adds no methods of its own; it only overrides the class-level
    tables that describe the DPR question-encoder checkpoints.

    Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning
    parameters.
    """

    # Same checkpoint tables as the non-fast question-encoder tokenizer.
    # NOTE(review): presumably read by the base tokenizer when loading a pretrained
    # checkpoint by name — confirm against the base class.
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    pretrained_init_configuration = QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION


# One predicted answer span, as built by `decode_best_spans`:
#   span_score      - start logit + end logit of the span
#   relevance_score - relevance logit of the passage the span comes from
#   doc_id          - index of that passage in the reader input
#   start_index     - first token position of the span (inclusive)
#   end_index       - last token position of the span (inclusive)
#   text            - decoded text of the tokens in [start_index, end_index]
DPRSpanPrediction = collections.namedtuple(
    "DPRSpanPrediction", ["span_score", "relevance_score", "doc_id", "start_index", "end_index", "text"]
)

# Reader output triple; `decode_best_spans` unpacks its first three items in this order.
DPRReaderOutput = collections.namedtuple("DPRReaderOutput", ["start_logits", "end_logits", "relevance_logits"])


# Shared documentation text for the reader tokenizers' `__call__`; it is attached to the
# classes below through the `add_start_docstrings` / `add_end_docstrings` decorators.
CUSTOM_DPR_READER_DOCSTRING = r"""
    Return a dictionary with the token ids of the input strings and other information to give to :obj:`.decode_best_spans`.
    It converts the strings of a question and different passages (title + text) into a sequence of ids (integers), using the tokenizer and vocabulary.
    The resulting `input_ids` is a matrix of size :obj:`(n_passages, sequence_length)` with the format:

        [CLS] <question token ids> [SEP] <titles ids> [SEP] <texts ids>

    Inputs:
        questions (:obj:`str`, :obj:`List[str]`):
            The questions to be encoded.
            You can specify one question for many passages. In this case, the question will be duplicated like :obj:`[questions] * n_passages`.
            Otherwise you have to specify as many questions as in :obj:`titles` or :obj:`texts`.
        titles (:obj:`str`, :obj:`List[str]`):
            The passages titles to be encoded. This can be a string, a list of strings if there are several passages.
        texts (:obj:`str`, :obj:`List[str]`):
            The passages texts to be encoded. This can be a string, a list of strings if there are several passages.
        padding (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`False`):
            Activate and control padding. Accepts the following values:

            * `True` or `'longest'`: pad to the longest sequence in the batch (or no padding if only a single sequence is provided),
            * `'max_length'`: pad to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`)
            * `False` or `'do_not_pad'` (default): No padding (i.e. can output batch with sequences of uneven lengths)
        truncation (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`False`):
            Activate and control truncation. Accepts the following values:

            * `True` or `'only_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`).
            * `False` or `'do_not_truncate'` (default): No truncation (i.e. can output batch with sequences length greater than the model max admissible input size)
        max_length (:obj:`Union[int, None]`, `optional`, defaults to :obj:`None`):
            Control the length for padding/truncation. Accepts the following values:

            * `None` (default): This will use the predefined model max length if required by one of the truncation/padding parameters. If the model has no specific max input length (e.g. XLNet) truncation/padding to max length is deactivated.
            * `any integer value` (e.g. `42`): Use this specific maximum length value if required by one of the truncation/padding parameters.
        return_tensors (:obj:`str`, `optional`, defaults to :obj:`None`):
            Can be set to 'tf', 'pt' or 'np' to return respectively TensorFlow :obj:`tf.constant`,
            PyTorch :obj:`torch.Tensor` or Numpy :obj:`np.ndarray` instead of a list of python integers.
        return_attention_mask (:obj:`bool`, `optional`, defaults to :obj:`None`):
            Whether to return the attention mask. If left to the default, will return the attention mask according
            to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.

            `What are attention masks? <../glossary.html#attention-mask>`__

    Return:
        A Dictionary of shape::

            {
                input_ids: list[list[int]],
                attention_mask: list[int] if return_attention_mask is True (default)
            }

        With the fields:

        - ``input_ids``: list of token ids to be fed to a model
        - ``attention_mask``: list of indices specifying which tokens should be attended to by the model

        """


# Mixin adding the DPR reader call convention (question + passage title + passage text)
# and the answer-span decoding helpers on top of a BERT-style tokenizer class.
# NOTE(review): described in a comment rather than a class docstring because
# `add_start_docstrings` presumably combines CUSTOM_DPR_READER_DOCSTRING with any
# existing class docstring — confirm before adding one.
@add_start_docstrings(CUSTOM_DPR_READER_DOCSTRING)
class CustomDPRReaderTokenizerMixin:
    def __call__(
        self,
        questions: Union[str, List[str]],
        titles: Optional[Union[str, List[str]]] = None,
        texts: Optional[Union[str, List[str]]] = None,
        padding: Union[bool, str] = False,
        truncation: Union[bool, str] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_attention_mask: Optional[bool] = None,
        **kwargs
    ) -> BatchEncoding:
        """
        Encode questions together with passage titles and texts.

        Three cases are handled:

        * neither ``titles`` nor ``texts`` is given: ``questions`` is forwarded unchanged to the parent
          tokenizer's ``__call__``;
        * exactly one of ``titles`` / ``texts`` is given: it is forwarded as the second sequence of a pair to
          the parent tokenizer's ``__call__``;
        * both are given: each passage is encoded as the (question, title) pair followed by the text ids
          (the text is encoded with ``add_special_tokens=False``), optionally cut to ``max_length``, and the
          batch is padded with ``self.pad``.

        ``**kwargs`` is only forwarded in the first two cases.

        Raises:
            AssertionError: if both ``titles`` and ``texts`` are given but have different lengths.
        """
        if titles is None and texts is None:
            return super().__call__(
                questions,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=return_tensors,
                return_attention_mask=return_attention_mask,
                **kwargs,
            )
        elif titles is None or texts is None:
            # Only one passage field was supplied: treat it as the second sequence of a pair.
            text_pair = titles if texts is None else texts
            return super().__call__(
                questions,
                text_pair,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=return_tensors,
                return_attention_mask=return_attention_mask,
                **kwargs,
            )

        # Normalise single strings to one-element lists so everything below is batched.
        titles = titles if not isinstance(titles, str) else [titles]
        texts = texts if not isinstance(texts, str) else [texts]
        n_passages = len(titles)
        # A single question is repeated once per passage.
        questions = questions if not isinstance(questions, str) else [questions] * n_passages
        assert len(titles) == len(
            texts
        ), "There should be as many titles than texts but got {} titles and {} texts.".format(len(titles), len(texts))

        # Padding/truncation are disabled here and applied once on the concatenated sequences below.
        encoded_question_and_titles = super().__call__(questions, titles, padding=False, truncation=False)["input_ids"]
        encoded_texts = super().__call__(texts, add_special_tokens=False, padding=False, truncation=False)["input_ids"]

        # `'do_not_truncate'` is a documented "no truncation" value but is a truthy string,
        # so it has to be excluded explicitly.
        should_truncate = max_length is not None and bool(truncation) and truncation != "do_not_truncate"

        all_input_ids = []
        for encoded_question_and_title, encoded_text in zip(encoded_question_and_titles, encoded_texts):
            passage_ids = encoded_question_and_title + encoded_text
            all_input_ids.append(passage_ids[:max_length] if should_truncate else passage_ids)
        encoded_inputs = {"input_ids": all_input_ids}

        if return_attention_mask is not False:
            # One 0/1 flag per token. Comparing the whole id list with `pad_token_id` (as was
            # done before) yields a single boolean per passage instead of a per-token mask.
            encoded_inputs["attention_mask"] = [
                [int(token_id != self.pad_token_id) for token_id in passage_ids] for passage_ids in all_input_ids
            ]
        return self.pad(encoded_inputs, padding=padding, max_length=max_length, return_tensors=return_tensors)

    def decode_best_spans(
        self,
        reader_input: BatchEncoding,
        reader_output: DPRReaderOutput,
        num_spans: int = 16,
        max_answer_length: int = 64,
        num_spans_per_passage: int = 4,
    ) -> List[DPRSpanPrediction]:
        """
        Get the span predictions for the extractive Q&A model.
        Outputs: `List` of `DPRSpanPrediction` sorted by descending `(relevance_score, span_score)`.
            Each `DPRSpanPrediction` is a `Tuple` with:
            **span_score**: ``float`` that corresponds to the score given by the reader for this span compared to other spans
                in the same passage. It corresponds to the sum of the start and end logits of the span.
            **relevance_score**: ``float`` that corresponds to the score of the each passage to answer the question,
                compared to all the other passages. It corresponds to the output of the QA classifier of the DPRReader.
            **doc_id**: ``int`` the id of the passage.
            **start_index**: ``int`` the start index of the span (inclusive).
            **end_index**: ``int`` the end index of the span (inclusive).
            **text**: ``str`` the decoded tokens of the span.

        Passages are visited by descending relevance logit; at most `num_spans_per_passage` spans are taken from
        each passage and at most `num_spans` predictions are returned overall.

        Examples::

            from transformers import DPRReader, DPRReaderTokenizer
            tokenizer = DPRReaderTokenizer.from_pretrained('facebook/dpr-reader-single-nq-base')
            model = DPRReader.from_pretrained('facebook/dpr-reader-single-nq-base')
            encoded_inputs = tokenizer(
                    questions=["What is love ?"],
                    titles=["Haddaway"],
                    texts=["'What Is Love' is a song recorded by the artist Haddaway"],
                    return_tensors='pt'
                )
            outputs = model(**encoded_inputs)
            predicted_spans = tokenizer.decode_best_spans(encoded_inputs, outputs)
            print(predicted_spans[0].text)  # best span

        """
        input_ids = reader_input["input_ids"]
        start_logits, end_logits, relevance_logits = reader_output[:3]
        n_passages = len(relevance_logits)
        # Passage indices ordered from most to least relevant.
        sorted_docs = sorted(range(n_passages), reverse=True, key=relevance_logits.__getitem__)
        nbest_spans_predictions: List[DPRSpanPrediction] = []
        for doc_id in sorted_docs:
            sequence_ids = list(input_ids[doc_id])
            # assuming question & title information is at the beginning of the sequence.
            # `index(sep, 2)` returns the first [SEP] found at position >= 2, i.e. the separator
            # closing the question whenever the question has at least one token.
            # NOTE(review): the previous comment called this the "second sep id"; as written,
            # title tokens (between the first and second [SEP]) remain eligible for answer
            # spans — confirm this is intended before changing it.
            passage_offset = sequence_ids.index(self.sep_token_id, 2) + 1
            # Ignore trailing padding: stop at the first pad token when the sequence ends with one.
            if sequence_ids[-1] == self.pad_token_id:
                sequence_len = sequence_ids.index(self.pad_token_id)
            else:
                sequence_len = len(sequence_ids)

            best_spans = self._get_best_spans(
                start_logits=start_logits[doc_id][passage_offset:sequence_len],
                end_logits=end_logits[doc_id][passage_offset:sequence_len],
                max_answer_length=max_answer_length,
                top_spans=num_spans_per_passage,
            )
            for start_index, end_index in best_spans:
                # Span indices are relative to the sliced logits; shift them back to sequence positions.
                start_index += passage_offset
                end_index += passage_offset
                nbest_spans_predictions.append(
                    DPRSpanPrediction(
                        span_score=start_logits[doc_id][start_index] + end_logits[doc_id][end_index],
                        relevance_score=relevance_logits[doc_id],
                        doc_id=doc_id,
                        start_index=start_index,
                        end_index=end_index,
                        text=self.decode(sequence_ids[start_index : end_index + 1]),
                    )
                )
            if len(nbest_spans_predictions) >= num_spans:
                break
        return nbest_spans_predictions[:num_spans]

    def _get_best_spans(
        self, start_logits: List[int], end_logits: List[int], max_answer_length: int, top_spans: int,
    ) -> List[tuple]:
        """
        Finds the best answer spans for the extractive Q&A model for one passage.

        Every span `(start, end)` with `start <= end` and at most `max_answer_length` tokens is scored with
        `start_logits[start] + end_logits[end]`. Spans are then picked by descending score, skipping any span
        that contains, or is contained in, an already chosen span.

        Returns a list of at most `top_spans` `(start_index, end_index)` tuples (both inclusive, relative to
        the given logits), best first. A non-positive `top_spans` yields an empty list.
        """
        scores = []
        for (start_index, start_score) in enumerate(start_logits):
            for (answer_length, end_score) in enumerate(end_logits[start_index : start_index + max_answer_length]):
                scores.append(((start_index, start_index + answer_length), start_score + end_score))
        scores = sorted(scores, key=lambda x: x[1], reverse=True)
        chosen_span_intervals = []
        for (start_index, end_index), _score in scores:
            # Checked before selecting so that `top_spans <= 0` returns nothing instead of
            # never reaching the limit and returning every non-nested span.
            if len(chosen_span_intervals) >= top_spans:
                break
            assert start_index <= end_index, "Wrong span indices: [{}:{}]".format(start_index, end_index)
            length = end_index - start_index + 1
            assert length <= max_answer_length, "Span is too long: {} > {}".format(length, max_answer_length)
            # Skip spans nested with (containing or contained in) a span that was already chosen.
            if any(
                start_index <= prev_start_index <= prev_end_index <= end_index
                or prev_start_index <= start_index <= end_index <= prev_end_index
                for (prev_start_index, prev_end_index) in chosen_span_intervals
            ):
                continue
            chosen_span_intervals.append((start_index, end_index))
        return chosen_span_intervals


@add_end_docstrings(CUSTOM_DPR_READER_DOCSTRING)
class DPRReaderTokenizer(CustomDPRReaderTokenizerMixin, BertTokenizer):
    r"""
    Constructs a DPRReaderTokenizer.

    :class:`~transformers.DPRReaderTokenizer` is almost identical to :class:`~transformers.BertTokenizer` and runs end-to-end
    tokenization: punctuation splitting + wordpiece.

    What is different is that it has three input strings: question, titles and texts that are combined to feed into the DPRReader model.

    Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
    parameters.

    """

    # Class-level tables for the reader checkpoints.
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = READER_PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = READER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    pretrained_init_configuration = READER_PRETRAINED_INIT_CONFIGURATION
    # NOTE(review): only `attention_mask` is listed here; presumably this controls which
    # additional inputs the tokenizer returns by default — confirm against the base class.
    model_input_names = ["attention_mask"]


@add_end_docstrings(CUSTOM_DPR_READER_DOCSTRING)
class DPRReaderTokenizerFast(CustomDPRReaderTokenizerMixin, BertTokenizerFast):
    r"""
    Constructs a DPRReaderTokenizerFast.

    :class:`~transformers.DPRReaderTokenizerFast` is almost identical to :class:`~transformers.BertTokenizerFast` and runs end-to-end
    tokenization: punctuation splitting + wordpiece.

    What is different is that it has three input strings: question, titles and texts that are combined to feed into the DPRReader model.

    Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning
    parameters.

    """

    # Same checkpoint tables as the non-fast reader tokenizer.
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = READER_PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = READER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    pretrained_init_configuration = READER_PRETRAINED_INIT_CONFIGURATION
    # NOTE(review): only `attention_mask` is listed here; presumably this controls which
    # additional inputs the tokenizer returns by default — confirm against the base class.
    model_input_names = ["attention_mask"]
